#include "mish_grad_n_d.h"
#include "mish_grad_tanhx_n_d.h"

using namespace MishGrad;
using namespace MishGradTanhx;

/**
 * Kernel entry point for mish_grad.
 *
 * Loads the tiling data from `tiling` and then selects one of two
 * implementations according to the active tiling key:
 *   - key 101: MishGradND,      initialised with grad, x and y (tanhx is not used);
 *   - key 201: MishGradTanhxND, initialised with grad, x, tanhx and y.
 * For any other tiling key the function returns without running an operator.
 *
 * @param grad      forwarded to Init() of the selected implementation
 * @param x         forwarded to Init() of the selected implementation
 * @param tanhx     forwarded to Init() only on the key-201 path
 * @param y         forwarded to Init() of the selected implementation
 * @param workspace not referenced by this function
 * @param tiling    source handed to GET_TILING_DATA
 *
 * NOTE(review): the GM_ADDR arguments are presumably device global-memory
 * addresses with y as the output buffer -- confirm against the operator
 * definition.
 */
extern "C" __global__ __aicore__ void mish_grad(GM_ADDR grad, GM_ADDR x, GM_ADDR tanhx, GM_ADDR y, GM_ADDR workspace,
                                                GM_ADDR tiling) {
    GET_TILING_DATA(tilingInfo, tiling);

    // Variant without tanhx: only grad, x and y take part.
    if (TILING_KEY_IS(101)) {
        MishGradND<DTYPE_X> plainKernel;
        plainKernel.Init(grad, x, y, &tilingInfo);
        plainKernel.Process();
        return;
    }

    // Variant with tanhx: tanhx is passed through as an extra argument.
    if (TILING_KEY_IS(201)) {
        MishGradTanhxND<DTYPE_X> tanhxKernel;
        tanhxKernel.Init(grad, x, tanhx, y, &tilingInfo);
        tanhxKernel.Process();
        return;
    }
}